{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "upi2EY4L9ei3"
   },
   "outputs": [],
   "source": [
    "# Copyright 2024 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mbF2F2miAT4a"
   },
   "source": [
    "# Generate and store embeddings with batch processing\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/alloydb/notebooks/embeddings_batch_processing.ipynb)\n",
    "\n",
    "---\n",
    "## Introduction\n",
    "\n",
    "This notebook demonstrates an efficient way to generate and store vector embeddings in AlloyDB. You'll learn how to:\n",
    "\n",
    "* **Optimize embedding generation**: Dynamically batch text chunks based on character length to generate more embeddings with each API call.\n",
    "* **Streamline storage**: Use [asyncio](https://docs.python.org/3/library/asyncio.html) to write the generated embeddings to AlloyDB concurrently, while further batches are still being embedded.\n",
    "\n",
    "Together, these techniques can substantially reduce the time needed to generate and store embeddings for large datasets."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FbcZUjT1yvTq"
   },
   "source": [
    "## What you'll need\n",
    "\n",
    "* A Google Cloud account and a Google Cloud project"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uy9KqgPQ4GBi"
   },
   "source": [
    "## Basic Setup\n",
    "### Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M_ppDxYf4Gqs"
   },
   "outputs": [],
   "source": [
    "%pip install \\\n",
    "    google-cloud-alloydb-connector[asyncpg]==1.4.0 \\\n",
    "    sqlalchemy==2.0.36 \\\n",
    "    pandas==2.2.3 \\\n",
    "    vertexai==1.70.0 \\\n",
    "    greenlet==3.1.1 \\\n",
    "    --quiet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Authenticate to Google Cloud within Colab\n",
    "If you're running this notebook in Google Colab, you need to authenticate as an IAM user."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from google.colab import auth\n",
    "\n",
    "auth.authenticate_user()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UCiNGP1Qxd6x"
   },
   "source": [
    "### Connect Your Google Cloud Project"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "SLUGlG6UE2CK",
    "outputId": "a284c046-00df-414a-9039-ddc5df12536d"
   },
   "outputs": [],
   "source": [
    "# @markdown Please fill in the value below with your GCP project ID and then run the cell.\n",
    "\n",
    "# Please fill in these values.\n",
    "project_id = \"my-project-id\"  # @param {type:\"string\"}\n",
    "\n",
    "# Quick input validations.\n",
    "assert project_id, \"⚠️ Please provide a Google Cloud project ID\"\n",
    "\n",
    "# Configure gcloud.\n",
    "!gcloud config set project {project_id}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "O-oqMC5Ox-ZM"
   },
   "source": [
    "### Enable APIs for AlloyDB and Vertex AI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "X-bzfFb4A-xK"
   },
   "source": [
    "Enable these APIs so that you can create an AlloyDB cluster and use Vertex AI as the embeddings service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "CKWrwyfzyTwH",
    "outputId": "f5131e77-2750-4cb1-b153-c52a13aaf284"
   },
   "outputs": [],
   "source": [
    "!gcloud services enable alloydb.googleapis.com aiplatform.googleapis.com"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Gn8g7-wCyZU6"
   },
   "source": [
    "## Set up AlloyDB\n",
    "You will need an AlloyDB for PostgreSQL instance for the following stages of this notebook. Set the following variables to connect to an existing instance or to create a new one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "8q2lc-Po1mPv",
    "outputId": "e268aea8-0514-4308-f5c7-1916031255b7"
   },
   "outputs": [],
   "source": [
    "# @markdown Please fill in the Google Cloud region and the names of your AlloyDB cluster, instance, and database. Once filled in, run the cell.\n",
    "\n",
    "# Please fill in these values.\n",
    "region = \"us-central1\"  # @param {type:\"string\"}\n",
    "cluster_name = \"my-cluster\"  # @param {type:\"string\"}\n",
    "instance_name = \"my-primary\"  # @param {type:\"string\"}\n",
    "database_name = \"test_db\"  # @param {type:\"string\"}\n",
    "table_name = \"investments\"\n",
    "password = input(\"Please provide a password to be used for 'postgres' database user: \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XXI1uUu3y8gc"
   },
   "outputs": [],
   "source": [
    "# Quick input validations.\n",
    "assert region, \"⚠️ Please provide a Google Cloud region\"\n",
    "assert cluster_name, \"⚠️ Please provide the name of your cluster\"\n",
    "assert instance_name, \"⚠️ Please provide the name of your instance\"\n",
    "assert database_name, \"⚠️ Please provide the name of your database\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T616pEOUygYQ"
   },
   "source": [
    "### Create an AlloyDB Instance\n",
    "If you have already created an AlloyDB cluster and instance, you can skip these steps and go straight to the `Connect to AlloyDB` section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xyZYX4Jo1vfh"
   },
   "source": [
    "> ⏳ - Creating an AlloyDB cluster may take a few minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "MQYni0NlTLzC",
    "outputId": "118d9a2b-2d9d-44ae-a33f-fb89ed6a2895"
   },
   "outputs": [],
   "source": [
    "!gcloud beta alloydb clusters create {cluster_name} --password={password} --region={region}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o8LkscYH5Vfp"
   },
   "source": [
    "Create an instance attached to our cluster with the following command.\n",
    "> ⏳ - Creating an AlloyDB instance may take a few minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "TkqQSWoY5Kab",
    "outputId": "78e02d10-5e14-457a-86c6-21348898bd0a"
   },
   "outputs": [],
   "source": [
    "!gcloud beta alloydb instances create {instance_name} --instance-type=PRIMARY --cpu-count=2 --region={region} --cluster={cluster_name}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BXsQ1UJv4ZVJ"
   },
   "source": [
    "To connect to your AlloyDB instance from this notebook, you will need to enable public IP on your instance. Alternatively, you can follow [these instructions](https://cloud.google.com/alloydb/docs/connect-external) to connect to an AlloyDB for PostgreSQL instance with Private IP from outside your VPC."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "OPVWsQB04Yyl",
    "outputId": "79f213ac-a069-4b15-e949-189f166dfca1"
   },
   "outputs": [],
   "source": [
    "!gcloud beta alloydb instances update {instance_name} --region={region} --cluster={cluster_name} --assign-inbound-public-ip=ASSIGN_IPV4 --database-flags=\"password.enforce_complexity=on\" --no-async"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_K86id-dcjcm"
   },
   "source": [
    "### Connect to AlloyDB\n",
    "\n",
    "This function will create a connection pool to your AlloyDB instance using the [AlloyDB Python connector](https://github.com/GoogleCloudPlatform/alloydb-python-connector). The AlloyDB Python connector will automatically create secure connections to your AlloyDB instance using mTLS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fYKVQzv2cjcm"
   },
   "outputs": [],
   "source": [
    "import asyncpg\n",
    "\n",
    "import sqlalchemy\n",
    "from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine\n",
    "\n",
    "from google.cloud.alloydb.connector import AsyncConnector, IPTypes\n",
    "\n",
    "async def init_connection_pool(connector: AsyncConnector, db_name: str, pool_size: int = 5) -> AsyncEngine:\n",
    "    # Full resource path of the AlloyDB instance to connect to\n",
    "    connection_string = f\"projects/{project_id}/locations/{region}/clusters/{cluster_name}/instances/{instance_name}\"\n",
    "\n",
    "    async def getconn() -> asyncpg.Connection:\n",
    "        conn: asyncpg.Connection = await connector.connect(\n",
    "            connection_string,\n",
    "            \"asyncpg\",\n",
    "            user=\"postgres\",\n",
    "            password=password,\n",
    "            db=db_name,\n",
    "            ip_type=IPTypes.PUBLIC,\n",
    "        )\n",
    "        return conn\n",
    "\n",
    "    pool = create_async_engine(\n",
    "        \"postgresql+asyncpg://\",\n",
    "        async_creator=getconn,\n",
    "        pool_size=pool_size,\n",
    "        max_overflow=0,\n",
    "    )\n",
    "    return pool\n",
    "\n",
    "connector = AsyncConnector()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i_yNN1MnJpTR"
   },
   "source": [
    "### Create a Database\n",
    "\n",
    "Next, you will create a database to store the data, using the connection pool. Because enabling public IP takes a few minutes, you may get an error saying that the instance has no public IP address. If so, wait a few minutes and retry this step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7PX05ndo_AMc",
    "outputId": "0931754a-aeb8-4895-e0b5-eeb01ffe5506"
   },
   "outputs": [],
   "source": [
    "from sqlalchemy import text, exc\n",
    "\n",
    "async def create_db(database_name, connector):\n",
    "    pool = await init_connection_pool(connector, \"postgres\")\n",
    "    async with pool.connect() as conn:\n",
    "        try:\n",
    "            # End the implicit transaction: CREATE DATABASE cannot run inside a transaction block.\n",
    "            await conn.execute(text(\"COMMIT\"))\n",
    "            await conn.execute(text(f\"CREATE DATABASE {database_name}\"))\n",
    "            print(f\"Database '{database_name}' created successfully\")\n",
    "        except exc.ProgrammingError as e:\n",
    "            # Raised, for example, if the database already exists.\n",
    "            print(e)\n",
    "    # Close the connections held by this temporary pool.\n",
    "    await pool.dispose()\n",
    "\n",
    "await create_db(database_name=database_name, connector=connector)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HdolCWyatZmG"
   },
   "source": [
    "### Download data\n",
    "\n",
    "The following cells download a sample CSV dataset, which you will then insert into your AlloyDB for PostgreSQL database."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Dzr-2VZIkvtY"
   },
   "source": [
    "Download the CSV file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5KkIQ2zSvQkN",
    "outputId": "f1980d73-4171-4fb1-b912-164187ba283b"
   },
   "outputs": [],
   "source": [
    "!gcloud storage cp gs://cloud-samples-data/alloydb/investments_data ./investments.csv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oFU13dCBlYHh"
   },
   "source": [
    "You can verify the download with the following command or in the \"Files\" tab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nQBs10I8vShh",
    "outputId": "e81e933b-819d-46ac-f4de-6a1f943faa48"
   },
   "outputs": [],
   "source": [
    "!ls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r16wPmxOBn_r"
   },
   "source": [
    "### Import data to your database\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this step you will:\n",
    "\n",
    "1. Create a table to store the data.\n",
    "2. Insert the data from the CSV file into the table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "v1pi9-8tB_pH"
   },
   "outputs": [],
   "source": [
    "# Prepare data\n",
    "import pandas as pd\n",
    "\n",
    "data = \"./investments.csv\"\n",
    "\n",
    "df = pd.read_csv(data)\n",
    "df['etf'] = df['etf'].map({'t': True, 'f': False})\n",
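    "# Fill missing ratings before converting to str; converting first would turn NaN into the string 'nan'.\n",
    "df['rating'] = df['rating'].fillna('').astype(str)"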
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 345
    },
    "id": "4R6tzuUtLypO",
    "outputId": "270d5fcd-b62d-4e3c-8c4e-25428798a350"
   },
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UstTWGJyL7j-"
   },
   "source": [
    "The data consists of the following columns:\n",
    "\n",
    "* **id**: A unique identifier for each row.\n",
    "* **ticker**: A string representing the stock symbol or ticker (e.g., \"AAPL\" for Apple, \"GOOG\" for Google).\n",
    "* **etf**: A boolean value indicating whether the asset is an ETF (True) or not (False).\n",
    "* **market**: A string representing the stock exchange where the asset is traded.\n",
    "* **rating**: Whether to hold, buy or sell a stock.\n",
    "* **overview**: A text field for a general overview or description of the asset.\n",
    "* **analysis**: A text field for a more detailed analysis of the asset.\n",
    "* **overview_embedding**: The embedding of `overview` (empty until embeddings are generated).\n",
    "* **analysis_embedding**: The embedding of `analysis` (empty until embeddings are generated).\n",
    "\n",
    "In this dataset, we need to embed two columns, `overview` and `analysis`. Their embeddings are stored in the `overview_embedding` and `analysis_embedding` columns, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KqpLkwbWCJaw"
   },
   "outputs": [],
   "source": [
    "create_table_cmd = sqlalchemy.text(\n",
    "    f\"\"\"\n",
    "    CREATE TABLE {table_name} (\n",
    "        id SERIAL PRIMARY KEY,\n",
    "        ticker VARCHAR(255) NOT NULL UNIQUE,\n",
    "        etf BOOLEAN,\n",
    "        market VARCHAR(255),\n",
    "        rating TEXT,\n",
    "        overview TEXT,\n",
    "        overview_embedding VECTOR(768),\n",
    "        analysis TEXT,\n",
    "        analysis_embedding VECTOR(768)\n",
    "    )\n",
    "    \"\"\"\n",
    ")\n",
    "\n",
    "\n",
    "insert_data_cmd = sqlalchemy.text(\n",
    "    f\"INSERT INTO {table_name} (id, ticker, etf, market, rating, overview, analysis)\\n\"\n",
    "    \"VALUES (:id, :ticker, :etf, :market, :rating, :overview, :analysis)\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qCsM2KXbdYiv"
   },
   "outputs": [],
   "source": [
    "# Enable pgvector, create the table, and insert the CSV rows in a single transaction.\n",
    "async def insert_data(pool):\n",
    "    async with pool.connect() as db_conn:\n",
    "        await db_conn.execute(sqlalchemy.text(\"CREATE EXTENSION IF NOT EXISTS vector;\"))\n",
    "        await db_conn.execute(create_table_cmd)\n",
    "        # A list of dicts makes SQLAlchemy run the INSERT once per row (executemany).\n",
    "        await db_conn.execute(\n",
    "            insert_data_cmd,\n",
    "            df.to_dict('records'),\n",
    "        )\n",
    "        await db_conn.commit()\n",
    "\n",
    "pool = await init_connection_pool(connector, database_name)\n",
    "await insert_data(pool)\n",
    "await pool.dispose()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IaC8uhlfEwam"
   },
   "source": [
    "## Building an Embeddings Workflow\n",
    "\n",
    "Now that we have created our database, we'll define the methods to carry out each step of the embeddings workflow.\n",
    "\n",
    "The workflow contains four major steps:\n",
    "1. **Read the Data:** Load the dataset into our program.\n",
    "2. **Batch the Data:** Divide the data into smaller batches for efficient processing.\n",
    "3. **Generate Embeddings:** Use an embedding model to create vector representations of the text. The text to embed can come from multiple columns in the table.\n",
    "4. **Update Original Table:** Write the generated embeddings to the corresponding `overview_embedding` and `analysis_embedding` columns of the table."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oIk5GxbnFaE3"
   },
   "source": [
    "#### Step 0: Configure logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wvYGGRRoFXl4"
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "\n",
    "# Configure the root logger to output messages with INFO level or above\n",
    "logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ekrEM22pJ2df"
   },
   "source": [
    "#### Step 1: Read the data\n",
    "\n",
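    "`get_source_data` streams the rows that still need embeddings (those where any `<col>_embedding` column is `NULL`) from the table and yields them one at a time for further processing."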
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IZgMik9XBW19"
   },
   "outputs": [],
   "source": [
    "from typing import AsyncIterator, List\n",
    "from sqlalchemy import RowMapping\n",
    "from sqlalchemy.ext.asyncio import AsyncEngine\n",
    "\n",
    "\n",
    "async def get_source_data(\n",
    "    pool: AsyncEngine, embed_cols: List[str]\n",
    ") -> AsyncIterator[RowMapping]:\n",
    "    \"\"\"Retrieves data from the database for embedding, excluding already embedded data.\n",
    "\n",
    "    Args:\n",
    "      pool: The AsyncEngine pool corresponding to the AlloyDB database.\n",
    "      embed_cols: A list of column names containing the data to be embedded.\n",
    "\n",
    "    Yields:\n",
    "      A single row of data, containing the 'id' and the specified `embed_cols`.\n",
    "      For example: {'id': 'id1', 'col1': 'val1', 'col2': 'val2'}\n",
    "    \"\"\"\n",
    "    logger = logging.getLogger(\"get_source_data\")\n",
    "\n",
    "    # Only embed columns which are not already embedded.\n",
    "    where_clause = \" OR \".join(f\"{col}_embedding IS NULL\" for col in embed_cols)\n",
    "    sql = f\"SELECT id, {', '.join(embed_cols)} FROM {table_name} WHERE {where_clause};\"\n",
    "    logger.info(f\"Running SQL query: {sql}\")\n",
    "\n",
    "    async with pool.connect() as conn:\n",
    "        async for row in await conn.stream(text(sql)):\n",
    "            logger.debug(f\"yielded row: {row._mapping['id']}\")\n",
    "            # Yield the row as a dictionary (RowMapping)\n",
    "            yield row._mapping"
   ]
  },
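  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the query in `get_source_data` concrete, the following minimal sketch rebuilds the same SQL string for a given table and list of columns, without connecting to the database. `build_source_query` is a hypothetical helper introduced only for illustration; like the notebook, it assumes that every embedded column `col` has a matching `col_embedding` column. Because the conditions are joined with `OR`, a row in which only one column still lacks an embedding is returned with all of its text columns, so its existing embeddings are regenerated as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "\n",
    "def build_source_query(table: str, embed_cols: List[str]) -> str:\n",
    "    # Hypothetical helper mirroring the SQL built inside get_source_data.\n",
    "    # A row is selected if at least one of its embedding columns is still NULL.\n",
    "    where_clause = ' OR '.join(f'{col}_embedding IS NULL' for col in embed_cols)\n",
    "    cols = ', '.join(embed_cols)\n",
    "    return f'SELECT id, {cols} FROM {table} WHERE {where_clause};'\n",
    "\n",
    "\n",
    "print(build_source_query('investments', ['overview', 'analysis']))\n",
    "# SELECT id, overview, analysis FROM investments WHERE overview_embedding IS NULL OR analysis_embedding IS NULL;"
   ]
  },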
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Kg54pvhjJ5kL"
   },
   "source": [
    "#### Step 2: Batch the data\n",
    "\n",
    "This code defines a function called `batch_source_data` that groups database rows into batches. Each batch is limited both by its total character count (`max_char_count`) and by the number of texts it contains (`max_instances_per_prediction`), two global variables that the function reads. Batching helps embedding generation in two ways:\n",
    "\n",
    "* **Fewer API calls:** Instead of sending many small requests, batching sends fewer, larger requests, which reduces per-request overhead.\n",
    "\n",
    "* **Staying within API limits:** The two limits keep each request within the input size that the embedding API accepts.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "76qq6G38CZfm"
   },
   "outputs": [],
   "source": [
    "from typing import Any, List\n",
    "import asyncio\n",
    "\n",
    "\n",
    "async def batch_source_data(\n",
    "    read_generator: AsyncIterator[RowMapping],\n",
    "    embed_cols: List[str],\n",
    ") -> AsyncIterator[List[dict[str, Any]]]:\n",
    "    \"\"\"\n",
    "    Groups data into batches for efficient embedding processing.\n",
    "\n",
    "    It is ensured that each batch adheres to predefined limits for character count\n",
    "    (`max_char_count`) and the number of embeddable instances (`max_instances_per_prediction`).\n",
    "\n",
    "    Args:\n",
    "      read_generator: An asynchronous generator yielding individual data rows.\n",
    "      embed_cols: A list of column names containing the data to be embedded.\n",
    "\n",
    "    Yields:\n",
    "      A list of rows, where each row contains data to be embedded.\n",
    "      For example:\n",
    "      [\n",
    "        {'id' : 'id1', 'col1': 'val1', 'col2': 'val2'},\n",
    "        ...\n",
    "      ]\n",
    "      where col1 and col2 are columns containing data to be embedded.\n",
    "    \"\"\"\n",
    "    logger = logging.getLogger(\"batch_data\")\n",
    "\n",
    "    global max_char_count\n",
    "    global max_instances_per_prediction\n",
    "\n",
    "    batch = []\n",
    "    batch_char_count = 0\n",
    "    batch_num = 0\n",
    "    batch_embed_cells = 0\n",
    "\n",
    "    async for row in read_generator:\n",
    "        # Char count in current row\n",
    "        row_char_count = 0\n",
    "        row_embed_cells = 0\n",
    "        for col in embed_cols:\n",
    "            if col in row and row[col] is not None:\n",
    "                row_char_count += len(row[col])\n",
    "                row_embed_cells += 1\n",
    "\n",
    "        # Skip the row if all columns to embed are empty.\n",
    "        if row_embed_cells == 0:\n",
    "            continue\n",
    "\n",
    "        # Start a new batch if adding this row would exceed the maximum character\n",
    "        # count or the maximum number of embedding instances. Checking `batch` first\n",
    "        # avoids yielding an empty batch when a single row exceeds a limit on its own.\n",
    "        if batch and (\n",
    "            batch_char_count + row_char_count > max_char_count\n",
    "            or batch_embed_cells + row_embed_cells > max_instances_per_prediction\n",
    "        ):\n",
    "            batch_num += 1\n",
    "            logger.info(f\"Yielded batch number: {batch_num} with length: {len(batch)}\")\n",
    "            yield batch\n",
    "            batch, batch_char_count, batch_embed_cells = [], 0, 0\n",
    "\n",
    "        # Add the current row to the batch\n",
    "        batch.append(row)\n",
    "        batch_char_count += row_char_count\n",
    "        batch_embed_cells += row_embed_cells\n",
    "\n",
    "    if batch:\n",
    "        batch_num += 1\n",
    "        logger.info(f\"Yielded batch number: {batch_num} with length: {len(batch)}\")\n",
    "        yield batch"
   ]
  },
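  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The greedy rule in `batch_source_data` is easier to follow on a handful of in-memory rows. The sketch below is a simplified, synchronous re-implementation: the hypothetical `greedy_batches` reads a plain list instead of an async generator and counts only non-empty texts. It closes the current batch as soon as the next row would push either the character count or the number of texts over its limit. With a 10-character, 3-text limit, row 1 is sent alone because adding row 2 would reach 12 characters, and row 3 is skipped because it has nothing to embed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Iterable, Iterator, List\n",
    "\n",
    "\n",
    "def greedy_batches(\n",
    "    rows: Iterable[Dict[str, Any]],\n",
    "    embed_cols: List[str],\n",
    "    max_chars: int,\n",
    "    max_instances: int,\n",
    ") -> Iterator[List[Dict[str, Any]]]:\n",
    "    # Hypothetical synchronous version of the batching rule in batch_source_data.\n",
    "    batch: List[Dict[str, Any]] = []\n",
    "    batch_chars = 0\n",
    "    batch_texts = 0\n",
    "    for row in rows:\n",
    "        # Only non-empty texts are sent to the embedding model.\n",
    "        texts = [row[col] for col in embed_cols if row.get(col)]\n",
    "        if not texts:\n",
    "            continue\n",
    "        row_chars = sum(len(text) for text in texts)\n",
    "        # Close the current batch if this row would break either limit.\n",
    "        if batch and (\n",
    "            batch_chars + row_chars > max_chars\n",
    "            or batch_texts + len(texts) > max_instances\n",
    "        ):\n",
    "            yield batch\n",
    "            batch, batch_chars, batch_texts = [], 0, 0\n",
    "        batch.append(row)\n",
    "        batch_chars += row_chars\n",
    "        batch_texts += len(texts)\n",
    "    if batch:\n",
    "        yield batch\n",
    "\n",
    "\n",
    "rows = [\n",
    "    {'id': 1, 'overview': 'aaaa', 'analysis': 'bb'},    # 6 chars, 2 texts\n",
    "    {'id': 2, 'overview': 'cccccc', 'analysis': None},  # 6 chars, 1 text\n",
    "    {'id': 3, 'overview': None, 'analysis': None},      # nothing to embed: skipped\n",
    "    {'id': 4, 'overview': 'dd', 'analysis': 'ee'},      # 4 chars, 2 texts\n",
    "]\n",
    "for batch in greedy_batches(rows, ['overview', 'analysis'], max_chars=10, max_instances=3):\n",
    "    print([row['id'] for row in batch])\n",
    "# Prints [1] and then [2, 4]"
   ]
  },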
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_L4EnrleJ8gy"
   },
   "source": [
    "#### Step 3: Generate embeddings\n",
    "\n",
    "This step converts your text data into numerical representations called \"embeddings.\" These embeddings capture the meaning and relationships between words, making them useful for various tasks like search, recommendations, and clustering.\n",
    "\n",
    "The code uses two functions to efficiently generate embeddings:\n",
    "\n",
    "**embed_text**\n",
    "\n",
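    "This function collects the non-empty text from the specified columns of a batch, sends it to Vertex AI in a single request, and maps the returned embeddings back to each row's `id`. Failed requests are retried after a fixed delay.\n",
    "\n",
    "**embed_objects_concurrently**\n",
    "\n",
    "This function is the orchestrator. It pulls batches from the batch generator and keeps at most `max_concurrency` `embed_text` calls running at a time, yielding each batch's results as soon as that batch completes. This keeps several requests in flight without sending an unbounded number of them at once."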
   ]
  },
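  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The id bookkeeping that `embed_text` performs can be illustrated without calling Vertex AI. In this minimal sketch, the hypothetical `fake_embed` stands in for the embedding model and returns one short vector per input text. The texts are flattened row by row and column by column, and the results are then consumed from a single iterator in exactly the same order, which is what lets each vector land in the right `<col>_embedding` field. As in the notebook, each vector is converted with `str()`, producing text such as `[3.0, 0.5]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, List\n",
    "\n",
    "\n",
    "def fake_embed(texts: List[str]) -> List[List[float]]:\n",
    "    # Hypothetical stand-in for the embedding model: one tiny vector per text.\n",
    "    return [[float(len(text)), 0.5] for text in texts]\n",
    "\n",
    "\n",
    "def embed_and_regroup(rows: List[Dict[str, Any]], cols: List[str]) -> List[Dict[str, Any]]:\n",
    "    # 1. Flatten: collect every non-empty text, row by row, column by column.\n",
    "    texts = [row[col] for row in rows for col in cols if row.get(col)]\n",
    "    vectors = iter(fake_embed(texts))\n",
    "    # 2. Regroup: walk the rows in the same order and take one vector per text.\n",
    "    results = []\n",
    "    for row in rows:\n",
    "        result: Dict[str, Any] = {'id': row['id']}\n",
    "        for col in cols:\n",
    "            result[f'{col}_embedding'] = str(next(vectors)) if row.get(col) else None\n",
    "        results.append(result)\n",
    "    return results\n",
    "\n",
    "\n",
    "sample = [\n",
    "    {'id': 1, 'overview': 'abc', 'analysis': None},\n",
    "    {'id': 2, 'overview': 'hi', 'analysis': 'xyz'},\n",
    "]\n",
    "for result in embed_and_regroup(sample, ['overview', 'analysis']):\n",
    "    print(result)\n",
    "# {'id': 1, 'overview_embedding': '[3.0, 0.5]', 'analysis_embedding': None}\n",
    "# {'id': 2, 'overview_embedding': '[2.0, 0.5]', 'analysis_embedding': '[3.0, 0.5]'}"
   ]
  },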
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4OYdrJk9Co0v"
   },
   "outputs": [],
   "source": [
    "from google.api_core.exceptions import ResourceExhausted\n",
    "from typing import Union\n",
    "from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel\n",
    "\n",
    "\n",
    "async def embed_text(\n",
    "    batch_data: List[dict[str, Any]],\n",
    "    model: TextEmbeddingModel,\n",
    "    cols_to_embed: List[str],\n",
    "    task_type: str = \"SEMANTIC_SIMILARITY\",\n",
    "    retries: int = 100,\n",
    "    delay: int = 30,\n",
    ") -> List[dict[str, Union[List[float], str]]]:\n",
    "    \"\"\"Embeds text data from a batch of records using a Vertex AI embedding model.\n",
    "\n",
    "    Args:\n",
    "      batch_data: A data batch containing records with text data to embed.\n",
    "      model: The Vertex AI `TextEmbeddingModel` to use for generating embeddings.\n",
    "      cols_to_embed: A list of column names containing the data to be embedded.\n",
    "      task_type: The task type for the embedding model. Defaults to\n",
    "        \"SEMANTIC_SIMILARITY\".\n",
    "        Supported task types: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types\n",
    "      retries: The maximum number of times to retry embedding generation in case\n",
    "        of errors. Defaults to 100.\n",
    "      delay: The delay in seconds between retries. Defaults to 30.\n",
    "\n",
    "    Returns:\n",
    "      A list of records containing ids and embeddings.\n",
    "      Example:\n",
    "        [\n",
    "          {\n",
    "            'id': 'id1',\n",
    "            'col1_embedding': [1.0, 1.1, ...],\n",
    "            'col2_embedding': [2.0, 2.1, ...],\n",
    "            ...\n",
    "          },\n",
    "          ...\n",
    "        ]\n",
    "      where col1 and col2 are columns containing data to be embedded.\n",
    "    Raises:\n",
    "      Exception: Raises the encountered exception if all retries fail.\n",
    "    \"\"\"\n",
    "    logger = logging.getLogger(\"embed_objects\")\n",
    "    global total_char_count\n",
    "\n",
    "    # Place all of the embeddings into a single list\n",
    "    inputs = []\n",
    "    for row in batch_data:\n",
    "        for col in cols_to_embed:\n",
    "            if col in row and row[col]:\n",
    "                inputs.append(TextEmbeddingInput(row[col], task_type))\n",
    "\n",
    "    # Retry loop\n",
    "    for attempt in range(retries):\n",
    "        try:\n",
    "            # Get embeddings for the text data\n",
    "            embeddings = await model.get_embeddings_async(inputs)\n",
    "\n",
    "            # Increase total char count\n",
    "            total_char_count += sum(len(item.text) for item in inputs)\n",
    "\n",
    "            # group the results together by id\n",
    "            embedding_iter = iter(embeddings)\n",
    "            results = []\n",
    "            for row in batch_data:\n",
    "                r = {\"id\": row[\"id\"]}\n",
    "                for col in cols_to_embed:\n",
    "                    if col in row and row[col]:\n",
    "                        r[f\"{col}_embedding\"] = str(next(embedding_iter).values)\n",
    "                    else:\n",
    "                        r[f\"{col}_embedding\"] = None\n",
    "                results.append(r)\n",
    "            return results\n",
    "\n",
    "        except Exception as e:\n",
    "            if attempt < retries - 1:  # Retry only if attempts are left\n",
    "                logger.warning(f\"Error: {e}. Retrying in {delay} seconds...\")\n",
    "                await asyncio.sleep(delay)  # Wait before retrying\n",
    "            else:\n",
    "                ids = [row[\"id\"] for row in batch_data]\n",
    "                logger.error(f\"Failed to get embeddings for ids {ids} after {retries} attempts.\")\n",
    "                # Re-raise the last error, as documented in the docstring above.\n",
    "                raise\n",
    "    return []\n",
    "\n",
    "\n",
    "async def embed_objects_concurrently(\n",
    "    cols_to_embed: List[str],\n",
    "    batch_data: AsyncIterator[List[dict[str, Any]]],\n",
    "    model: TextEmbeddingModel,\n",
    "    task_type: str,\n",
    "    max_concurrency: int = 5,\n",
    ") -> AsyncIterator[List[dict[str, Union[str, List[float]]]]]:\n",
    "    \"\"\"Embeds text data concurrently from an asynchronous batch data generator.\n",
    "\n",
    "    Args:\n",
    "      cols_to_embed: A list of column names containing the data to be embedded.\n",
    "      batch_data: A data batch containing records with text data to embed.\n",
    "      model: The Vertex AI `TextEmbeddingModel` to use for generating embeddings.\n",
    "      task_type: The task type for the embedding model.\n",
    "        Supported task types: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types\n",
    "      max_concurrency: The maximum number of embedding tasks to run concurrently.\n",
    "        Defaults to 5.\n",
    "    Yields:\n",
    "      A list of records containing ids and embeddings.\n",
    "    \"\"\"\n",
    "    logger = logging.getLogger(\"embed_objects\")\n",
    "\n",
    "    # Keep track of pending tasks\n",
    "    pending: set[asyncio.Task] = set()\n",
    "    has_next = True\n",
    "    while pending or has_next:\n",
    "        while len(pending) < max_concurrency and has_next:\n",
    "            try:\n",
    "                data = await batch_data.__anext__()\n",
    "                coro = embed_text(data, model, cols_to_embed, task_type)\n",
    "                pending.add(asyncio.ensure_future(coro))\n",
    "            except StopAsyncIteration:\n",
    "                has_next = False\n",
    "\n",
    "        if pending:\n",
    "            done, pending = await asyncio.wait(\n",
    "                pending, return_when=asyncio.FIRST_COMPLETED\n",
    "            )\n",
    "            for task in done:\n",
    "                result = task.result()\n",
    "                logger.info(f\"Embedding task completed: Processed {len(result)} rows.\")\n",
    "                yield result"
   ]
  },
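  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The refill loop in `embed_objects_concurrently` is a general bounded-concurrency pattern built on `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)`. The sketch below isolates it with hypothetical stand-ins (`numbers` for the batch generator, `work` for `embed_text`) and records the largest number of tasks in flight, which never exceeds `max_concurrency`. Because notebook cells already run inside an event loop, where `asyncio.run()` raises an error, the hypothetical `run_in_thread` helper runs the demo on a worker thread with its own loop; inside the notebook you could equally write `await bounded(numbers(10), 3)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import concurrent.futures\n",
    "from typing import AsyncIterator, List, Set, Tuple\n",
    "\n",
    "\n",
    "async def numbers(n: int) -> AsyncIterator[int]:\n",
    "    # Hypothetical async source standing in for batch_source_data.\n",
    "    for i in range(n):\n",
    "        yield i\n",
    "\n",
    "\n",
    "async def work(i: int) -> int:\n",
    "    # Hypothetical stand-in for embed_text: wait a little, then return a result.\n",
    "    await asyncio.sleep(0.01 * (i % 3))\n",
    "    return i * i\n",
    "\n",
    "\n",
    "async def bounded(source: AsyncIterator[int], max_concurrency: int) -> Tuple[List[int], int]:\n",
    "    pending: Set[asyncio.Task] = set()\n",
    "    results: List[int] = []\n",
    "    peak = 0  # Largest number of tasks in flight at once.\n",
    "    has_next = True\n",
    "    while pending or has_next:\n",
    "        # Refill the pool of running tasks up to the concurrency limit.\n",
    "        while len(pending) < max_concurrency and has_next:\n",
    "            try:\n",
    "                item = await source.__anext__()\n",
    "                pending.add(asyncio.ensure_future(work(item)))\n",
    "            except StopAsyncIteration:\n",
    "                has_next = False\n",
    "        peak = max(peak, len(pending))\n",
    "        if pending:\n",
    "            # Wake up as soon as any task finishes, then loop to refill.\n",
    "            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)\n",
    "            results.extend(task.result() for task in done)\n",
    "    return results, peak\n",
    "\n",
    "\n",
    "def run_in_thread(coro):\n",
    "    # asyncio.run() raises inside a notebook's running event loop,\n",
    "    # so run the coroutine on a worker thread that has its own loop.\n",
    "    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:\n",
    "        return executor.submit(asyncio.run, coro).result()\n",
    "\n",
    "\n",
    "squares, peak = run_in_thread(bounded(numbers(10), max_concurrency=3))\n",
    "print(sorted(squares), peak)\n",
    "# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] 3"
   ]
  },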
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FjErJPrJKA2j"
   },
   "source": [
    "#### Step 4: Update original table\n",
    "\n",
    "After generating embeddings for your text data, you need to store them in your database. This step efficiently updates your original table with the newly created embeddings.\n",
    "\n",
    "This process uses two functions to manage database updates:\n",
    "\n",
    "**batch_update_rows**\n",
    "1. This function takes a batch of data (including the embeddings) and updates the corresponding rows in your database table.\n",
    "2. It constructs an SQL UPDATE query to modify specific columns with the embedding values.\n",
    "3. It runs the updates for the whole batch in one transaction and commits them.\n",
    "\n",
    "\n",
    "**batch_update_rows_concurrently**\n",
    "\n",
    "1. This function consumes the stream of embedded batches and writes them to the database concurrently.\n",
    "2. It wraps each call to `batch_update_rows` in an asyncio task, one task per batch.\n",
    "3. It keeps at most `max_concurrency` tasks running at once, so that the database and the connection pool aren't overloaded.\n",
    "4. Each time a task finishes, it starts a task for the next batch, until every batch has been written."
   ]
  },
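  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is a minimal, database-free sketch of the two techniques in this step, assuming a hypothetical `fake_update` coroutine in place of `batch_update_rows` and an illustrative table name `movies` (neither is part of this notebook). `build_update_query` shows the shape of the parameterized `UPDATE` for `cols_to_embed = [\"analysis\", \"overview\"]`, and `update_concurrently` shows how `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)` caps the number of in-flight tasks while still processing every batch.\n",
    "\n",
    "```python\n",
    "import asyncio\n",
    "from typing import Any, AsyncIterator, Dict, List, Set\n",
    "\n",
    "\n",
    "def build_update_query(table_name: str, cols_to_embed: List[str]) -> str:\n",
    "    \"\"\"Builds an UPDATE with one named bind parameter per embedding column.\"\"\"\n",
    "    assignments = \", \".join(f\"{col}_embedding = :{col}_embedding\" for col in cols_to_embed)\n",
    "    return f\"UPDATE {table_name} SET {assignments} WHERE id = :id;\"\n",
    "\n",
    "\n",
    "async def fake_update(batch: List[Dict[str, Any]], log: List[int]) -> None:\n",
    "    \"\"\"Hypothetical stand-in for batch_update_rows: records the batch size.\"\"\"\n",
    "    await asyncio.sleep(0.01)  # simulate database latency\n",
    "    log.append(len(batch))\n",
    "\n",
    "\n",
    "async def batches(n: int, size: int) -> AsyncIterator[List[Dict[str, Any]]]:\n",
    "    \"\"\"Yields n batches of `size` fake rows with sequential ids.\"\"\"\n",
    "    for b in range(n):\n",
    "        yield [{\"id\": b * size + i} for i in range(size)]\n",
    "\n",
    "\n",
    "async def update_concurrently(\n",
    "    data: AsyncIterator[List[Dict[str, Any]]], max_concurrency: int\n",
    ") -> Dict[str, int]:\n",
    "    \"\"\"Runs fake_update on every batch with at most max_concurrency tasks in flight.\"\"\"\n",
    "    log: List[int] = []\n",
    "    pending: Set[asyncio.Task] = set()\n",
    "    peak = 0\n",
    "    has_next = True\n",
    "    while pending or has_next:\n",
    "        # Top up the running tasks until the concurrency limit is reached\n",
    "        while len(pending) < max_concurrency and has_next:\n",
    "            try:\n",
    "                batch = await data.__anext__()\n",
    "                pending.add(asyncio.ensure_future(fake_update(batch, log)))\n",
    "            except StopAsyncIteration:\n",
    "                has_next = False\n",
    "        peak = max(peak, len(pending))\n",
    "        if pending:\n",
    "            # Wake up as soon as any task finishes, which frees a slot\n",
    "            done, pending = await asyncio.wait(\n",
    "                pending, return_when=asyncio.FIRST_COMPLETED\n",
    "            )\n",
    "            for task in done:\n",
    "                task.result()  # re-raise any exception from the task\n",
    "    return {\"rows\": sum(log), \"batches\": len(log), \"peak\": peak}\n",
    "\n",
    "\n",
    "print(build_update_query(\"movies\", [\"analysis\", \"overview\"]))\n",
    "print(asyncio.run(update_concurrently(batches(12, 50), max_concurrency=5)))\n",
    "```\n",
    "\n",
    "Run as a script, this prints the query and `{'rows': 600, 'batches': 12, 'peak': 5}`: all 600 rows reach the stand-in update, and no more than 5 tasks are ever in flight. In a notebook cell, where an event loop is already running, use `await update_concurrently(...)` instead of `asyncio.run(...)`."
   ]
  },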
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lEyvhlOCCr7F"
   },
   "outputs": [],
   "source": [
    "from sqlalchemy import text\n",
    "\n",
    "\n",
    "async def batch_update_rows(\n",
    "    pool: AsyncEngine, data: List[dict[str, Any]], cols_to_embed: List[str]\n",
    ") -> None:\n",
    "    \"\"\"Updates rows in the database with embedding data.\n",
    "\n",
    "    Args:\n",
    "      pool: The AsyncEngine pool corresponding to the AlloyDB database.\n",
    "      data: A data batch containing records with text embeddings.\n",
    "      cols_to_embed: A list of column names containing the data to be embedded.\n",
    "    \"\"\"\n",
    "    update_query = f\"\"\"\n",
    "    UPDATE {table_name}\n",
    "    SET {', '.join([f'{col}_embedding = :{col}_embedding' for col in cols_to_embed])}\n",
    "    WHERE id = :id;\n",
    "    \"\"\"\n",
    "    logger = logging.getLogger(\"update_rows\")\n",
    "    async with pool.connect() as conn:\n",
    "        await conn.execute(\n",
    "            text(update_query),\n",
    "            # A list of parameter dicts executes the UPDATE once per row (executemany)\n",
    "            parameters=data,\n",
    "        )\n",
    "        await conn.commit()\n",
    "    logger.info(f\"Updated {len(data)} rows in database.\")\n",
    "\n",
    "\n",
    "async def batch_update_rows_concurrently(\n",
    "    pool: AsyncEngine,\n",
    "    embed_data: AsyncIterator[List[dict[str, Any]]],\n",
    "    cols_to_embed: List[str],\n",
    "    max_concurrency: int = 5,\n",
    "):\n",
    "    \"\"\"Updates database rows concurrently with embedding data.\n",
    "\n",
    "    Args:\n",
    "      pool: The AsyncEngine pool corresponding to the AlloyDB database.\n",
    "      embed_data: An async iterator that yields batches of records with text embeddings.\n",
    "      cols_to_embed: A list of column names containing the data to be embedded.\n",
    "      max_concurrency: The maximum number of database update tasks to run concurrently.\n",
    "        Defaults to 5.\n",
    "    \"\"\"\n",
    "    logger = logging.getLogger(\"update_rows\")\n",
    "    # Keep track of pending tasks\n",
    "    pending: set[asyncio.Task] = set()\n",
    "    has_next = True\n",
    "    while pending or has_next:\n",
    "        while len(pending) < max_concurrency and has_next:\n",
    "            try:\n",
    "                data = await embed_data.__anext__()\n",
    "                coro = batch_update_rows(pool, data, cols_to_embed)\n",
    "                pending.add(asyncio.ensure_future(coro))\n",
    "            except StopAsyncIteration:\n",
    "                has_next = False\n",
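    "        if pending:\n",
    "            done, pending = await asyncio.wait(\n",
    "                pending, return_when=asyncio.FIRST_COMPLETED\n",
    "            )\n",
    "            # Surface database errors instead of silently dropping them\n",
    "            for task in done:\n",
    "                task.result()\n",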
    "\n",
    "    logger.info(\"All database update tasks completed.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HSv4DwzbJc5J"
   },
   "source": [
    "## Run the embeddings workflow\n",
    "\n",
    "This runs the complete embeddings workflow:\n",
    "\n",
    "1. Getting source data\n",
    "2. Batching source data\n",
    "3. Generating embeddings for batches\n",
    "4. Updating data batches in the original table\n"
   ]
  },
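  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The workflow chains async generators: each stage pulls items from the previous stage only when it needs them, so early batches can be embedded and written before later rows have been read. The following toy sketch reproduces that shape with hypothetical stand-ins (`fake_source`, `batch_by_chars`, `fake_embed`) in place of the notebook's functions. The character-budget batching is an assumption modeled on the Introduction's description, not a copy of `batch_source_data`.\n",
    "\n",
    "```python\n",
    "import asyncio\n",
    "from typing import Any, AsyncIterator, Dict, List\n",
    "\n",
    "Row = Dict[str, Any]\n",
    "\n",
    "\n",
    "async def fake_source(texts: List[str]) -> AsyncIterator[Row]:\n",
    "    \"\"\"Hypothetical stand-in for get_source_data: yields one row at a time.\"\"\"\n",
    "    for i, text in enumerate(texts):\n",
    "        yield {\"id\": i, \"overview\": text}\n",
    "\n",
    "\n",
    "async def batch_by_chars(\n",
    "    rows: AsyncIterator[Row], max_chars: int, max_rows: int\n",
    ") -> AsyncIterator[List[Row]]:\n",
    "    \"\"\"Groups rows so each batch stays within a character budget and a row limit.\"\"\"\n",
    "    batch: List[Row] = []\n",
    "    chars = 0\n",
    "    async for row in rows:\n",
    "        n = len(row[\"overview\"])\n",
    "        # Emit the current batch if adding this row would break either limit\n",
    "        if batch and (chars + n > max_chars or len(batch) >= max_rows):\n",
    "            yield batch\n",
    "            batch, chars = [], 0\n",
    "        batch.append(row)\n",
    "        chars += n\n",
    "    if batch:\n",
    "        yield batch  # emit the final, partially filled batch\n",
    "\n",
    "\n",
    "async def fake_embed(batches: AsyncIterator[List[Row]]) -> AsyncIterator[List[Row]]:\n",
    "    \"\"\"Hypothetical stand-in for the embedding stage: adds a dummy 1-d vector.\"\"\"\n",
    "    async for batch in batches:\n",
    "        yield [{**row, \"overview_embedding\": [float(len(row[\"overview\"]))]} for row in batch]\n",
    "\n",
    "\n",
    "async def run_pipeline(texts: List[str], max_chars: int, max_rows: int) -> List[int]:\n",
    "    \"\"\"Drains the pipeline and returns the size of each batch that reaches the end.\"\"\"\n",
    "    stream = fake_embed(batch_by_chars(fake_source(texts), max_chars, max_rows))\n",
    "    return [len(batch) async for batch in stream]\n",
    "\n",
    "\n",
    "texts = [\"x\" * n for n in [40, 30, 50, 10, 70]]\n",
    "print(asyncio.run(run_pipeline(texts, max_chars=80, max_rows=250)))\n",
    "```\n",
    "\n",
    "This prints `[2, 2, 1]`: adding the 50-character row would push the first batch past 80 characters, so it starts a new batch, and the 70-character row does the same to the second batch. A single row longer than `max_chars` still gets a batch of its own rather than being dropped. Inside Jupyter, where an event loop is already running, use `await run_pipeline(...)` instead of `asyncio.run(...)`."
   ]
  },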
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "syO1Zq3o5PnI",
    "outputId": "8db5edfc-7b9e-46da-bda8-123444033b37"
   },
   "outputs": [],
   "source": [
    "import vertexai\n",
    "import time\n",
    "from vertexai.language_models import TextEmbeddingModel\n",
    "\n",
    "### Define variables ###\n",
    "\n",
    "# Max token count for the embeddings API\n",
    "max_tokens = 20000\n",
    "\n",
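    "# For some tokenizers and text, 1 token corresponds to roughly 3-4 characters.\n",
    "# This is a very general guideline and can vary significantly, so use the low end (3)\n",
    "# as a conservative character budget per request: 20000 * 3 = 60000 characters.\n",
    "max_char_count = max_tokens * 3\n",
    "# Maximum number of text inputs per embeddings API request\n",
    "max_instances_per_prediction = 250\n",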
    "\n",
    "cols_to_embed = [\"analysis\", \"overview\"]\n",
    "\n",
    "# Model to use for generating embeddings\n",
    "model_name = \"text-embedding-004\"\n",
    "\n",
    "# Generate optimised embeddings for a given task\n",
    "# Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#supported_task_types\n",
    "task = \"SEMANTIC_SIMILARITY\"\n",
    "\n",
    "total_char_count = 0\n",
    "\n",
    "### Embeddings workflow ###\n",
    "\n",
    "\n",
    "async def run_embeddings_workflow(\n",
    "    pool_size: int = 10,\n",
    "    embed_data_concurrency: int = 20,\n",
    "    batch_update_concurrency: int = 10,\n",
    "):\n",
    "    \"\"\"Orchestrates the end-to-end workflow for generating and storing embeddings.\n",
    "\n",
    "    The workflow includes the following major steps:\n",
    "\n",
    "    1. Data Retrieval: Fetches data from the database that requires embedding.\n",
    "    2. Batching: Divides the data into batches for optimized processing.\n",
    "    3. Embedding Generation: Generates embeddings concurrently for the batched\n",
    "        data using the Vertex AI model.\n",
    "    4. Database Update: Updates the database concurrently with the generated\n",
    "        embeddings.\n",
    "\n",
    "    Args:\n",
    "        pool_size: The size of the database connection pool. Defaults to 10.\n",
    "        embed_data_concurrency: The maximum number of concurrent tasks for generating embeddings.\n",
    "            Defaults to 20.\n",
    "        batch_update_concurrency: The maximum number of concurrent tasks for updating the database.\n",
    "            Defaults to 10.\n",
    "    \"\"\"\n",
    "    # Set up connections to the database\n",
    "    pool = await init_connection_pool(connector, database_name, pool_size=pool_size)\n",
    "\n",
    "    try:\n",
    "        # Initialise Vertex AI and the model to be used to generate embeddings\n",
    "        vertexai.init(project=project_id, location=region)\n",
    "        model = TextEmbeddingModel.from_pretrained(model_name)\n",
    "\n",
    "        # Wall-clock time for display; monotonic clock for measuring elapsed time\n",
    "        start_wall_time = time.time()\n",
    "        start_time = time.monotonic()\n",
    "\n",
    "        # Fetch source data from the database\n",
    "        source_data = get_source_data(pool, cols_to_embed)\n",
    "\n",
    "        # Divide the source data into batches for efficient processing\n",
    "        batch_data = batch_source_data(source_data, cols_to_embed)\n",
    "\n",
    "        # Generate embeddings for the batched data concurrently\n",
    "        embeddings_data = embed_objects_concurrently(\n",
    "            cols_to_embed, batch_data, model, task, max_concurrency=embed_data_concurrency\n",
    "        )\n",
    "\n",
    "        # Update the database with the generated embeddings concurrently\n",
    "        await batch_update_rows_concurrently(\n",
    "            pool, embeddings_data, cols_to_embed, max_concurrency=batch_update_concurrency\n",
    "        )\n",
    "\n",
    "        elapsed_time = time.monotonic() - start_time\n",
    "        end_wall_time = time.time()\n",
    "    finally:\n",
    "        # Release database connections and close the connector, even if a step fails\n",
    "        await pool.dispose()\n",
    "        await connector.close()\n",
    "\n",
    "    # time.ctime expects seconds since the epoch, so print the wall-clock timestamps\n",
    "    print(f\"Job started at: {time.ctime(start_wall_time)}\")\n",
    "    print(f\"Job ended at: {time.ctime(end_wall_time)}\")\n",
    "    print(f\"Total run time: {elapsed_time:.2f} seconds\")\n",
    "    print(f\"Total characters embedded: {total_char_count}\")\n",
    "\n",
    "\n",
    "await run_embeddings_workflow()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "python-docs-samples",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
